'''
Copyright 2020 Sensetime X-lab. All Rights Reserved

Main Function:
    1. Implementation for action_type_head, including basic processes.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from ctools.torch_utils import ResFCBlock, fc_block, one_hot
from ctools.torch_utils import CategoricalPdPytorch
from distar.model.alphastar.module_utils import build_activation


class ActionTypeHead(nn.Module):
    r"""
        Overview:
            The action type head uses lstm_output and scalar_context to get
            action_type_logits, action_type and its autoregressive_embedding.
        Interface:
            __init__, forward
    """

    def __init__(self, cfg):
        '''
            Overview: initialize architecture.
            Arguments:
                - cfg (:obj:`dict`): head architecture definition. Accessed both as attributes (``cfg.input_dim``)
                  and via ``cfg.get``, so an attribute-style dict (e.g. EasyDict) is expected. Keys read here:
                  activation, input_dim, res_dim, norm_type, res_num, action_num, context_dim, action_map_dim,
                  gate_dim, and optionally weight_norm / drop_ratio. ``use_mask`` is read later in ``forward``.
        '''
        super(ActionTypeHead, self).__init__()
        self.cfg = cfg
        self.act = build_activation(cfg.activation)  # activation type comes from cfg.activation
        self.project = fc_block(cfg.input_dim, cfg.res_dim)
        blocks = [ResFCBlock(cfg.res_dim, self.act, cfg.norm_type) for _ in range(cfg.res_num)]
        self.res = nn.Sequential(*blocks)
        # NOTE(review): weight_norm, drop_Z and drop_ratio are stored but not used by forward below.
        self.weight_norm = cfg.get('weight_norm', False)
        self.drop_Z = torch.nn.Dropout(p=cfg.get('drop_ratio', 0.0))
        self.drop_ratio = cfg.get('drop_ratio', 0.0)
        # GLU producing one logit per action type, conditioned on scalar_context.
        self.action_fc = build_activation('glu')(cfg.res_dim, cfg.action_num, cfg.context_dim)

        # Layers used to build the autoregressive embedding.
        self.action_map_fc = fc_block(cfg.action_num, cfg.action_map_dim, activation=self.act, norm_type=None)
        self.pd = CategoricalPdPytorch
        self.glu1 = build_activation('glu')(cfg.action_map_dim, cfg.gate_dim, cfg.context_dim)
        self.glu2 = build_activation('glu')(cfg.input_dim, cfg.gate_dim, cfg.context_dim)
        self.action_num = cfg.action_num

    def forward(self, lstm_output, scalar_context, action_type_mask, temperature=1.0, action_type=None):
        '''
            Overview: lstm_output is projected to cfg.res_dim, passed through cfg.res_num ResFCBlocks (normalisation
                      type cfg.norm_type, activation cfg.activation) and turned into one logit per action type by
                      a GLU conditioned on scalar_context. If cfg.use_mask is set, logits of unavailable action
                      types are pushed down by 1e9. The logits are then divided by a fixed 0.8. When action_type
                      is not given, it is sampled from the softmax of these logits; when it is given (e.g. the
                      ground-truth label in supervised learning), that entry of the mask is forced to 1 first so
                      the label is never masked out.
                      autoregressive_embedding is the sum of two GLU projections (output size cfg.gate_dim), both
                      conditioned on scalar_context: one of the one-hot action_type after an fc layer
                      (cfg.action_map_dim, activation cfg.activation), and one of lstm_output.
            Arguments:
                - lstm_output (:obj:`tensor`): The output of the LSTM
                - scalar_context (:obj:`tensor`): A 1D tensor of certain scalar features, include available_actions,
                                                  cumulative_statistics, beginning_build_order
                - action_type_mask (:obj:`tensor`): 0-1 value tensor marking available action types, indexed as
                                                    [batch, action_type]. It is copied, not modified in place.
                - temperature (:obj:`float`): currently NOT used (see NOTE(review) below); kept for interface
                                              compatibility.
                - action_type (:obj:`tensor`): optional integer action type per batch row, shape (B,) (a (B, 1)
                                               tensor is also accepted); if None it is sampled.
            Returns:
                - (:obj`tensor`): action_type_logits, masked (if cfg.use_mask) and divided by 0.8
                - (:obj`tensor`): action_type, either the given one or sampled from the logits
                - (:obj`tensor`): autoregressive_embedding that combines information from lstm_output
                                  and the chosen action_type.
        '''
        x = self.project(lstm_output)  # embed lstm_output into a vector of size cfg.res_dim
        x = self.res(x)  # cfg.res_num residual fc blocks
        x = self.action_fc(x, scalar_context)  # one logit per action type, no normalization
        # TODO(nyz) show warning info about the masked action_type label
        # The mask is edited below; work on a private copy that is also detached from the autograd graph,
        # because the caller's mask may feed other network parts.
        action_type_mask = action_type_mask.detach().clone()
        if action_type is not None:
            # Never mask out the provided action type. Vectorised form of
            # ``mask[i, action_type[i]] = 1`` for each row i, avoiding a per-row ``.item()`` host sync.
            cols = action_type.reshape(-1).to(device=action_type_mask.device, dtype=torch.long)
            rows = torch.arange(cols.shape[0], device=action_type_mask.device)
            action_type_mask[rows, cols] = 1
        if self.cfg.use_mask:
            x -= (1 - action_type_mask) * 1e9
        # NOTE(review): the ``temperature`` argument is ignored; logits are always divided by a fixed 0.8,
        # both when sampling and when action_type is given. Confirm whether this is intended before
        # switching to ``temperature`` (that would change outputs of existing callers/checkpoints).
        x = x.div(0.8)
        if action_type is None:
            p = F.softmax(x, dim=1)
            handle = self.pd(p)
            action_type = handle.sample()

        # autoregressive embedding
        action_one_hot = one_hot(action_type, self.action_num)  # one-hot version of action_type
        embedding1 = self.action_map_fc(action_one_hot)
        embedding1 = self.glu1(embedding1, scalar_context)
        embedding2 = self.glu2(lstm_output, scalar_context)
        embedding = embedding1 + embedding2

        return x, action_type, embedding
